# import torch


# def log(dist):
#     dist[dist<np.exp(1)] = 